{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c002ee57",
   "metadata": {},
   "source": [
    "# MLP Inference\n",
    "\n",
    "In this example we combine insights from the previous examples and\n",
    "use TT-NN with PyTorch to perform a simple MLP inference task. This\n",
    "demonstrates how to use TT-NN for tensor operations and model inference.\n",
    "\n",
    "Let's create the example file, `ttnn_mlp_inference_mnist.py`.\n",
    "\n",
    "## Import Libraries\n",
    "\n",
    "This script imports the libraries needed to run inference on the MNIST digit classification task with a multi-layer perceptron (MLP) accelerated by Tenstorrent hardware:\n",
    "\n",
    "- `torch` provides tensor utilities and the `DataLoader` used to iterate over the dataset.\n",
    "- `torchvision` and its `transforms` submodule download the MNIST dataset and apply tensor conversion and normalization to the input images.\n",
    "- `ttnn` (TT-NN) is the core interface for compiling and running neural network operations on Tenstorrent devices, including tensor creation, data layout transformation, and layer computation (e.g., `linear`, `relu`).\n",
    "- `os` is used to check whether a pretrained weights file exists on disk.\n",
    "- `loguru` provides structured logging for the loading status, per-sample predictions, and the final accuracy.\n",
    "\n",
    "`numpy` is also imported, although the code cells in this example do not call it directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8e8a746",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import numpy as np\n",
    "import ttnn\n",
    "import os\n",
    "from loguru import logger"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d23308c8",
   "metadata": {},
   "source": [
    "## Open the Device\n",
    "\n",
    "Open the Tenstorrent device (device ID 0) that will run the program."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc24097c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open Tenstorrent device\n",
    "device = ttnn.open_device(device_id=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f75cd6fe",
   "metadata": {},
   "source": [
    "## Load MNIST Test Data\n",
    "\n",
    "Load the MNIST test split, convert each 28x28 grayscale image to a\n",
    "tensor, and normalize it. Then create a DataLoader with a batch size\n",
    "of 1 to iterate through the dataset, so that inference runs on one\n",
    "image at a time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6482d274",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST data\n",
    "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n",
    "testset = torchvision.datasets.MNIST(root=\"./data\", train=False, download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=False)"
   ]
  },
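  {
   "cell_type": "markdown",
   "id": "b7e1a0c4",
   "metadata": {},
   "source": [
    "To make the effect of the transform above concrete: `ToTensor()` scales 8-bit pixel values to the range [0, 1], and `Normalize((0.5,), (0.5,))` then computes `(x - 0.5) / 0.5`, which maps that range to [-1, 1]. The minimal sketch below reproduces this arithmetic in plain Python. The helper name `normalize_pixel` is hypothetical and is not part of torchvision.\n",
    "\n",
    "```python\n",
    "def normalize_pixel(raw, mean=0.5, std=0.5):\n",
    "    # Hypothetical helper: mimics ToTensor() followed by Normalize()\n",
    "    # for a single 8-bit pixel value.\n",
    "    scaled = raw / 255.0          # ToTensor(): [0, 255] -> [0.0, 1.0]\n",
    "    return (scaled - mean) / std  # Normalize(): subtract mean, divide by std\n",
    "\n",
    "\n",
    "# A black pixel and a white pixel\n",
    "print([normalize_pixel(p) for p in (0, 255)])  # [-1.0, 1.0]\n",
    "```"
   ]
  },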
  {
   "cell_type": "markdown",
   "id": "c14e13d0",
   "metadata": {},
   "source": [
    "## Load Pretrained MLP Weights\n",
    "\n",
    "Load the pretrained MLP weights from the file `mlp_mnist_weights.pt`. To generate this file, run the script `train_and_export_mlp.py`.\n",
    "\n",
    "If the weights file is not found, random weight values are generated instead so that the rest of the pipeline can still be tested, but expect poor prediction results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad17492d",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(\"mlp_mnist_weights.pt\"):\n",
    "    # Pretrained weights\n",
    "    weights = torch.load(\"mlp_mnist_weights.pt\")\n",
    "    W1 = ttnn.from_torch(weights[\"W1\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b1 = ttnn.from_torch(weights[\"b1\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    W2 = ttnn.from_torch(weights[\"W2\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b2 = ttnn.from_torch(weights[\"b2\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    W3 = ttnn.from_torch(weights[\"W3\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b3 = ttnn.from_torch(weights[\"b3\"], dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    logger.info(\"Loaded pretrained weights from mlp_mnist_weights.pt\")\n",
    "else:\n",
    "    # Random weights for MLP - will not predict correctly\n",
    "    logger.warning(\"mlp_mnist_weights.pt not found, using random weights\")\n",
    "    W1 = ttnn.rand((128, 28 * 28), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b1 = ttnn.rand((128,), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    W2 = ttnn.rand((64, 128), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b2 = ttnn.rand((64,), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    W3 = ttnn.rand((10, 64), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    b3 = ttnn.rand((10,), dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)"
   ]
  },
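  {
   "cell_type": "markdown",
   "id": "e3d92f18",
   "metadata": {},
   "source": [
    "The loading code above expects the weights file to hold a dictionary with the keys `W1`, `b1`, `W2`, `b2`, `W3`, and `b3`, with each weight stored as `(out_features, in_features)`. The sketch below shows one way such a dictionary could be assembled from a PyTorch model. It is an illustration under assumptions, not the contents of `train_and_export_mlp.py`: the `nn.Sequential` model is hypothetical, its layer sizes are taken from the random-weight branch above, and the model is left untrained.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical model matching the shapes used above: 784 -> 128 -> 64 -> 10.\n",
    "# The real training script may define and train its model differently.\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(28 * 28, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 10),\n",
    ")\n",
    "\n",
    "# nn.Linear stores its weight as (out_features, in_features), which is why\n",
    "# the inference code transposes each weight before calling ttnn.linear.\n",
    "linear_layers = [m for m in model if isinstance(m, torch.nn.Linear)]\n",
    "weights = {}\n",
    "for idx, layer in enumerate(linear_layers, start=1):\n",
    "    weights[f'W{idx}'] = layer.weight.detach().clone()\n",
    "    weights[f'b{idx}'] = layer.bias.detach().clone()\n",
    "\n",
    "for name, tensor in weights.items():\n",
    "    print(name, tuple(tensor.shape))\n",
    "\n",
    "# After training, the dictionary could be written under the file name the\n",
    "# loader looks for (left commented out because this model is untrained):\n",
    "# torch.save(weights, 'mlp_mnist_weights.pt')\n",
    "```\n",
    "\n",
    "The save call is left commented out on purpose: writing untrained weights to `mlp_mnist_weights.pt` would make the loader above pick them up as if they were pretrained."
   ]
  },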
  {
   "cell_type": "markdown",
   "id": "396ded06",
   "metadata": {},
   "source": [
    "## Accuracy Tracking, Inference Loop, and Image Flattening\n",
    "\n",
    "The following code performs inference on the first five test samples from the MNIST dataset using a multi-layer perceptron (MLP) executed on Tenstorrent hardware via the TT-NN API. It first initializes counters for the number of correct predictions and the number of samples processed. For each sample, the input image is flattened into a 1x784 row vector and converted from a PyTorch tensor to a TT-NN tensor with bfloat16 precision and TILE_LAYOUT for efficient execution.\n",
    "\n",
    "The tensor is then passed sequentially through three fully connected layers: each of the first two layers applies a linear transformation followed by a ReLU activation, while the final layer produces raw logits for the 10 output classes (digits 0–9). For each layer, the weights are transposed and the bias is reshaped to a single row to match the dimensions TT-NN expects. The final output is converted back to a PyTorch tensor, and the class with the highest logit is selected as the predicted label. The prediction is compared with the true label to update the counters, and the result is logged. Once all five samples are processed, the script logs the overall prediction accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4fa175",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "\n",
    "for i, (image, label) in enumerate(testloader):\n",
    "    # Only run inference on the first five samples\n",
    "    if i >= 5:\n",
    "        break\n",
    "\n",
    "    # Flatten the 1x1x28x28 image batch into a 1x784 row vector\n",
    "    image = image.view(1, -1).to(torch.float32)\n",
    "    \n",
    "    # Convert the PyTorch tensor to a TT-NN tensor with bfloat16 data type\n",
    "    # and TILE_LAYOUT, which is needed for efficient computation on the\n",
    "    # Tenstorrent device.\n",
    "    image_tt = ttnn.from_torch(image, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)\n",
    "    \n",
    "    # Layer 1\n",
    "    # Transpose the weights to match TT-NN's expected shape, reshape the\n",
    "    # bias to 1x128 for broadcasting, then compute the first layer output.\n",
    "    W1_final = ttnn.transpose(W1, -2, -1)\n",
    "    b1_final = ttnn.reshape(b1, [1, -1])\n",
    "    out1 = ttnn.linear(image_tt, W1_final, bias=b1_final)\n",
    "    out1 = ttnn.relu(out1)\n",
    "    \n",
    "    # Layer 2\n",
    "    # Same pattern as Layer 1, but with different weights and biases.\n",
    "    W2_final = ttnn.transpose(W2, -2, -1)\n",
    "    b2_final = ttnn.reshape(b2, [1, -1])\n",
    "    out2 = ttnn.linear(out1, W2_final, bias=b2_final)\n",
    "    out2 = ttnn.relu(out2)\n",
    "    \n",
    "    # Layer 3\n",
    "    # Final layer with 10 outputs (one per digit, 0-9). No ReLU activation\n",
    "    # here, because this layer produces the raw logits.\n",
    "    W3_final = ttnn.transpose(W3, -2, -1)\n",
    "    b3_final = ttnn.reshape(b3, [1, -1])\n",
    "    out3 = ttnn.linear(out2, W3_final, bias=b3_final)\n",
    "    \n",
    "    # Convert result back to torch\n",
    "    prediction = ttnn.to_torch(out3)\n",
    "    predicted_label = torch.argmax(prediction, dim=1).item()\n",
    "    \n",
    "    correct += predicted_label == label.item()\n",
    "    total += 1\n",
    "    \n",
    "    logger.info(f\"Sample {i+1}: Predicted={predicted_label}, Actual={label.item()}\")\n",
    "    \n",
    "logger.info(f\"\\nTT-NN MLP Inference Accuracy: {correct}/{total} = {100.0 * correct / total:.2f}%\")"
   ]
  },
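  {
   "cell_type": "markdown",
   "id": "4c6a8b5e",
   "metadata": {},
   "source": [
    "When checking results from the loop above, it can help to compare against the same computation done on the CPU. The sketch below is a plain PyTorch reference of the three-layer forward pass, assuming weights stored as `(out_features, in_features)` under the keys `W1` to `W3` and `b1` to `b3`. The function name `mlp_reference` and the random stand-in data are hypothetical and only demonstrate the shapes. Because the TT-NN path runs in bfloat16, its logits will generally differ slightly from a float32 reference.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def mlp_reference(image, weights):\n",
    "    # image: (1, 784) float tensor. weights: dict with W1..W3 and b1..b3.\n",
    "    # Each layer computes x @ W.T + b, mirroring the transpose and the\n",
    "    # bias reshape done before each ttnn.linear call.\n",
    "    out = torch.relu(image @ weights['W1'].T + weights['b1'].reshape(1, -1))\n",
    "    out = torch.relu(out @ weights['W2'].T + weights['b2'].reshape(1, -1))\n",
    "    return out @ weights['W3'].T + weights['b3'].reshape(1, -1)  # raw logits\n",
    "\n",
    "\n",
    "# Random stand-in weights and input, used only to demonstrate the shapes.\n",
    "torch.manual_seed(0)\n",
    "shapes = {\n",
    "    'W1': (128, 784), 'b1': (128,),\n",
    "    'W2': (64, 128), 'b2': (64,),\n",
    "    'W3': (10, 64), 'b3': (10,),\n",
    "}\n",
    "demo_weights = {name: torch.rand(shape) for name, shape in shapes.items()}\n",
    "demo_image = torch.rand(1, 28 * 28)\n",
    "\n",
    "logits = mlp_reference(demo_image, demo_weights)\n",
    "print(logits.shape)  # torch.Size([1, 10])\n",
    "print(torch.argmax(logits, dim=1).item())  # index of the predicted class\n",
    "```\n",
    "\n",
    "With the pretrained weights loaded through `torch.load`, the same function can be called on each flattened image, and its `argmax` compared with the label predicted on the device."
   ]
  },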
  {
   "cell_type": "markdown",
   "id": "65edabfc",
   "metadata": {},
   "source": [
    "## Close the Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "350adfa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "ttnn.close_device(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25856722",
   "metadata": {},
   "source": [
    "## Full Example and Output\n",
    "\n",
    "Let's put everything together in a complete example that can be run directly.\n",
    "\n",
    "[ttnn_mlp_inference_mnist.py](https://github.com/tenstorrent/tt-metal/tree/main/ttnn/tutorials/basic_python/ttnn_mlp_inference_mnist.py)\n",
    "\n",
    "Run the following script to generate output:\n",
    "\n",
    "``` console\n",
    "$ python3 $TT_METAL_HOME/ttnn/tutorials/basic_python/ttnn_mlp_inference_mnist.py\n",
    "2025-07-07 13:03:41.990 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:41.992 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:41.998 | info     |          Device | Opening user mode device driver (tt_cluster.cpp:190)\n",
    "2025-07-07 13:03:41.998 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:41.999 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:42.006 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:42.007 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:42.013 | info     |   SiliconDriver | Harvesting mask for chip 0 is 0x100 (NOC0: 0x100, simulated harvesting mask: 0x0). (cluster.cpp:282)\n",
    "2025-07-07 13:03:42.110 | info     |   SiliconDriver | Opened PCI device 7; KMD version: 1.34.0; API: 1; IOMMU: disabled (pci_device.cpp:198)\n",
    "2025-07-07 13:03:42.172 | info     |   SiliconDriver | Opening local chip ids/pci ids: {0}/[7] and remote chip ids {} (cluster.cpp:147)\n",
    "2025-07-07 13:03:42.182 | info     |   SiliconDriver | Software version 6.0.0, Ethernet FW version 6.14.0 (Device 0) (cluster.cpp:1039)\n",
    "2025-07-07 13:03:42.268 | info     |           Metal | AI CLK for device 0 is:   1000 MHz (metal_context.cpp:128)\n",
    "2025-07-07 13:03:42.886 | info     |           Metal | Initializing device 0. Program cache is enabled (device.cpp:428)\n",
    "2025-07-07 13:03:42.888 | warning  |           Metal | Unable to bind worker thread to CPU Core. May see performance degradation. Error Code: 22 (hardware_command_queue.cpp:74)\n",
    "2025-07-07 13:03:44.852 | INFO     | __main__:main:32 - Loaded pretrained weights from mlp_mnist_weights.pt\n",
    "2025-07-07 13:03:48.677 | INFO     | __main__:main:87 - Sample 1: Predicted=7, Actual=7\n",
    "2025-07-07 13:03:48.682 | INFO     | __main__:main:87 - Sample 2: Predicted=2, Actual=2\n",
    "2025-07-07 13:03:48.686 | INFO     | __main__:main:87 - Sample 3: Predicted=1, Actual=1\n",
    "2025-07-07 13:03:48.690 | INFO     | __main__:main:87 - Sample 4: Predicted=0, Actual=0\n",
    "2025-07-07 13:03:48.695 | INFO     | __main__:main:87 - Sample 5: Predicted=4, Actual=4\n",
    "2025-07-07 13:03:48.695 | INFO     | __main__:main:89 - \n",
    "TT-NN MLP Inference Accuracy: 5/5 = 100.00%\n",
    "2025-07-07 13:03:48.695 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "2025-07-07 13:03:48.696 | info     |           Metal | Closing mesh device 0 (mesh_device.cpp:488)\n",
    "2025-07-07 13:03:48.696 | info     |           Metal | Closing device 0 (device.cpp:468)\n",
    "2025-07-07 13:03:48.696 | info     |           Metal | Disabling and clearing program cache on device 0 (device.cpp:783)\n",
    "2025-07-07 13:03:48.697 | info     |           Metal | Closing mesh device 1 (mesh_device.cpp:488)\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
